Какой самый простой способ преобразовать тензор формы (batch_size, height, width), заполненный значениями n, в тензор формы (batch_size, n, height, width)? Я создал решение ниже, но похоже, что есть более простой и быстрый способ сделать это def batch_tensor_to_onehot (tnsr, классы): tnsr = tnsr.unsqueeze (1) res = [] для cls в диапазоне (классы): res.append ((tnsr == cls) .long ()) вернуть torch.cat (res, dim = 1)
2021-02-20 08:20:04
Вы можете использовать torch.nn.functional.one_hot. Для вашего случая: a = torch.nn.functional.one_hot (tnsr, num_classes = классы) out = a.permute (0, 3, 1, 2) | Вы также можете использовать Tensor.scatter_, который избегает .permute, но, возможно, труднее понять, чем простой метод, предложенный @Alpha. def batch_tensor_to_onehot (tnsr, классы): результат = torch.zeros (tnsr.shape [0], классы, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device) result.scatter_ (1, tnsr.unsqueeze (1), 1) вернуть результат Результаты сравнительного анализа Мне было любопытно, и я решил протестировать три подхода. Я обнаружил, что между предлагаемыми методами нет существенной относительной разницы в отношении размера партии, ширины или высоты. В первую очередь отличительным фактором было количество классов. Конечно, как и в любом тесте, пробег может отличаться. Контрольные показатели были собраны с использованием случайных индексов и с использованием размера партии, высоты, ширины = 100. Каждый эксперимент повторяли 20 раз, сообщая среднее значение. Эксперимент num_classes = 100 запускается один раз перед профилированием для разминки. Результаты CPU показывают, что исходный метод, вероятно, лучше всего подходил для num_classes меньше 30, в то время как для GPU подход scatter_ кажется самым быстрым. Тесты выполнены на Ubuntu 18.04, NVIDIA 2060 Super, i7-9700K Код, используемый для сравнительного анализа, представлен ниже: импортный фонарик из tqdm импорт tqdm время импорта импортировать matplotlib.pyplot как plt def batch_tensor_to_onehot_slavka (tnsr, классы): tnsr = tnsr.unsqueeze (1) res = [] для cls в диапазоне (классы): res.append ((tnsr == cls) .long ()) вернуть torch.cat (res, dim = 1) def batch_tensor_to_onehot_alpha (tnsr, классы): результат = torch.nn.functional.one_hot (tnsr, num_classes = классы) вернуть result.permute (0, 3, 1, 2) def batch_tensor_to_onehot_jodag (tnsr, классы): результат = torch.zeros (tnsr.shape [0], классы, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device) result.scatter_ (1, tnsr.unsqueeze (1), 1) вернуть результат def main (): num_classes = [2, 10, 25, 50, 100] высота = 100 ширина = 100 шс = [100] * 20 для d в ['cpu', 'cuda']: times_slavka = [] times_alpha = [] times_jodag = [] разминка = Правда для c в tqdm ([num_classes [-1]] + num_classes, ncols = 0): цлавка = 0 тальфа = 0 tjodag = 0 для b в bs: tnsr = torch.randint (c, (b, height, width)). to (device = d) t0 = время. время () y = batch_tensor_to_onehot_slavka (tnsr, c) torch.cuda.synchronize () цлавка + = time.time () - t0 если не разминка: times_slavka.append (цлавка / лен (бс)) для b в bs: tnsr = torch.randint (c, (b, height, width)). to (device = d) t0 = время. время () y = batch_tensor_to_onehot_alpha (tnsr, c) torch.cuda.synchronize () talpha + = time.time () - t0 если не разминка: times_alpha.append (talpha / len (bs)) для b в bs: tnsr = torch.randint (c, (b, height, width)). to (device = d) t0 = время. время () y = batch_tensor_to_onehot_jodag (tnsr, c) torch.cuda.synchronize () tjodag + = time.time () - t0 если не разминка: times_jodag.append (tjodag / len (bs)) разминка = Ложь fig = plt.figure () ax = fig.subplots () ax.plot (num_classes, times_slavka, label = 'Славка-кошка') ax.plot (num_classes, times_alpha, label = 'Alpha-one_hot') ax.plot (num_classes, times_jodag, label = 'jodag-scatter_') ax.set_xlabel ('количество_классов') ax.set_ylabel ('время (а)') ax.set_title (f '{d} benchmark') ax.legend () plt.savefig (f '{d} .png') plt.show () если __name__ == "__main__": главный() | Твой ответ StackExchange.ifUsing ("редактор", function () { StackExchange.using ("externalEditor", function () { StackExchange.using ("сниппеты", function () { StackExchange.snippets.init (); }); }); }, "фрагменты кода"); StackExchange.ready (функция () { var channelOptions = { теги: "" .split (""), id: "1" }; initTagRenderer ("". split (""), "" .split (""), channelOptions); StackExchange.using ("externalEditor", function () { // Должен запускать редактор после сниппетов, если сниппеты включены if (StackExchange.settings.snippets.snippetsEnabled) { StackExchange.using ("сниппеты", function () { createEditor (); }); } еще { createEditor (); } }); function createEditor () { StackExchange.prepareEditor ({ useStacksEditor: ложь, heartbeatType: 'ответ', autoActivateHeartbeat: ложь, convertImagesToLinks: правда, noModals: правда, showLowRepImageUploadWarning: true, РепутацияToPostImages: 10, bindNavPrevention: правда, постфикс: "", imageUploader: { brandingHtml: "На основе \ u003ca href = \" https: //imgur.com/ \ "\ u003e \ u003csvg class = \" svg-icon \ "width = \" 50 \ "height = \" 18 \ "viewBox = \ "0 0 50 18 \" fill = \ "none \" xmlns = \ "http: //www.w3.org/2000/svg \" \ u003e \ u003cpath d = \ "M46.1709 9.17788C46.1709 8.26454 46.2665 7.94324 47.1084 7.58816C47.4091 7.46349 47.7169 7.36433 48.0099 7.26993C48.9099 6.97997 49.672 6.73443 49.672 5.93063C49.672 5.22043 48.9832 4.61182 48.14144.61182C47.4335 4.61182 46.7256234.9516284.61182C47.4335 4.61182 46.725623.4916284,61182C47.4335 4.61182 46.725623.4916284,61182 43.1481 6.59048V11.9512C43.1481 13.2535 43.6264 13.8962 44.6595 13.8962C45.6924 13.8962 46.1709 13.253546.1709 11.9512V9.17788Z \ "/ \ u003e \ u003cpath d = \" M32.492 10.1419C32.492 12.6954 34.1182 14.0484 37.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.5985 10.1419V6.56162623221.5985 10.1419V6.56162623221.59839 41.106.590623221.59839 41.106.59062 38.5948 5.28821 38.5948 6.59049V9.60062C38.5948 10.8521 38.2696 11.5455 37.0451 11.5455C35.8209 11.5455 35.4954 10.8521 35.4954 9.60062V6.59049C35.4954 5.28821 35.0173 4.66232 34.0034703602905.28821 35.0173 4.66232 34.0034216290c fill-rule = \ "evenodd \" clip-rule = \ "evenodd \" d = \ "M25.6622 17.6335C27.8049 17.6335 29.3739 16.9402 30.2537 15.6379C30.8468 14.7755 30.9615 13.5579 30.9615 11.9512V6.5904921 30.4833 4.688 29.4502 4.66231C28.9913 4.66231 28.4555 4.94978 28.1109 5.50789C27.499 4.86533 26.7335 4.56087 25.7005 4.56087C23.1369 4.56087 21.0134 6.57349 21.0134 9.27932C21.0134 11.9852 23.003 13.9137.132832C21.0134 11.9852 23.003 13.9137.1326.916.2816.1328.916.1328.916.28.916.1328.48.28.916.1328.48 C28. 1256 12.8854 28,1301 12,9342 28,1301 12.983C28.1301 14,4373 27,2502 15,2321 25,777 15.2321C24.8349 15,2321 24,1352 14,9821 23,5661 14.7787C23.176 14,6393 22,8472 14,5218 22,5437 14.5218C21.7977 14,5218 21,2429 15,0123 21,2429 15.6887C21.2429 16,7375 22,9072 17,6335 25,6622 17.6335ZM24.1317 9,27932 C24.1317 7.94324 24.9928 7.09766 26.1024 7.09766C27.2119 7.09766 28.0918 7.94324 28.0918 9.27932C28.0918 10.6321 27.2311 11.5116 26.1024 11.5116C24.9737 11.5116 24.1317 10.6491 24.1317 9.279325C \ u003c \ dc \ "d \" \ "\" \ "\" \ " 8045 13,2535 17,2637 13,8962 18,2965 13,8962C19,3298 13,8962 19,8079 13,2535 19,8079 11,9512V8.12928C19.8079 5,82936 18,4879 4,62866 16,4027 4,62866C15,1594 4,62866 14,279 4,98375 13,3609 5,88013C4,62866 14,279 4,98375 13,3609 5,88013C4 28669 4,6669 5,5669 4,666 8,566 7,5669 4,6669 4,6669 4,6669 5,5669 8,566 7,5669 4,666 9 5,566 58314 4.9328 7.10506 4.66232 6.51203 4.66232C5.47873 4.66232 5.00066 5.28821 5.00066 6.59049V11.9512C5.00066 13.2535 5.47873 13.8962 6.51203 13.8962C7.54479 13.8962 8.0232 13 .2535 8.0232 11.9512V8.90741C8.0232 7.58817 8.44431 6.91179 9.53458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8962365C13.4171375 13.97157.89 13.97157.89 13.97157.89 9 C16.4027 6.91179 16.8045 7.58817 16.8045 8.94108V11.9512Z \ "/> \ u003cpath d = \" M3.31675 6.59049C3.31675 5.28821 2.83866 4.66232 1.82471 4.66232C0.791758 4.66232 0.313354 5.28821 0.313354 6.5 1.82471 13.8962C2.85798 13.8962 3.31675 13.2535 3.31675 11.9512V6.59049Z \ "/> \ u003cpath d = \" M1.87209 0.400291C0.843612 0.400291 0 1.1159 0 1.98861C0 2.87869 0.822846208673.57007 3.6861C0 2.87869 0.822846208673.57007 3.85673.57869 0.822841 C3.7234 1.1159 2.90056 0.400291 1.87209 0.400291Z \ "fill = \" # 1BB76E \ "/> \ u003c / svg> \ u003c / a>", contentPolicyHtml: "Вклады пользователей под лицензией \ u003ca href = \" https: //stackoverflow.com/help/licensing \ "\ u003ecc by-sa \ u003c / a> \ u003ca href = \" https://stackoverflow.com / legal / content-policy \ "> (политика содержания) \ u003c / a>", allowUrls: true }, onDemand: правда, discardSelector: ".discard-answer" , немедленноShowMarkdownHelp: true, enableTables: true, enableSnippets: true }); } }); Спасибо за ответ на Stack Overflow! Обязательно ответьте на вопрос. Предоставьте подробную информацию и поделитесь своим исследованием! Но избегайте… Просить о помощи, разъяснениях или отвечать на другие ответы. Делать заявления, основанные на мнении; подкрепите их рекомендациями или личным опытом. Чтобы узнать больше, ознакомьтесь с нашими советами по написанию отличных ответов. Черновик сохранен Черновик отклонен Зарегистрируйтесь или войдите под своим ником StackExchange.ready (функция () { StackExchange.helpers.onClickDraftSave ('# ссылка для входа'); }); Зарегистрируйтесь с помощью Google Зарегистрируйтесь через Facebook Зарегистрируйтесь, используя электронную почту и пароль Представлять на рассмотрение Опубликовать как гость Имя Электронное письмо Обязательно, но не отображается StackExchange.ready ( function () { StackExchange.openid.initPostLogin ('. New-post-login', 'https% 3a% 2f% 2fstackoverflow.com% 2fquestions% 2f62245173% 2fpytorch-transform-tensor-to-one-hot% 23new-answer', 'question_page' ); } ); Опубликовать как гость Имя Электронное письмо Обязательно, но не отображается Разместите свой ответ Отказаться Нажимая «Опубликовать ответ», вы соглашаетесь с нашими условиями обслуживания, политикой конфиденциальности и политикой использования файлов cookie. Не тот ответ, который вы ищете? Посмотрите другие вопросы с метками python pytorch tenor one-hot-encoding или задайте свой вопрос.